{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Inspired by https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.contrib import rnn\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# configuration\n",
    "#                        O * W + b -> 10 labels for each image, O[? 28], W[28 10], B[10]\n",
    "#                       ^ (O: output 28 vec from 28 vec input)\n",
    "#                       |\n",
    "#      +-+  +-+       +--+\n",
    "#      |1|->|2|-> ... |28| time_step_size = 28\n",
    "#      +-+  +-+       +--+\n",
    "#       ^    ^    ...  ^\n",
    "#       |    |         |\n",
    "# img1:[28] [28]  ... [28]\n",
    "# img2:[28] [28]  ... [28]\n",
    "# img3:[28] [28]  ... [28]\n",
    "# ...\n",
    "# img128 or img256 (batch_size or test_size 256)\n",
    "#      each input size = input_vec_size=lstm_size=28\n",
    "\n",
    "# configuration variables\n",
    "input_vec_size = lstm_size = 28\n",
    "time_step_size = 28\n",
    "\n",
    "batch_size = 128\n",
    "test_size = 256\n",
    "\n",
    "def init_weights(shape):\n",
    "    return tf.Variable(tf.random_normal(shape, stddev=0.01))\n",
    "\n",
    "def model(X, W, B, lstm_size):\n",
    "    # X, input shape: (batch_size, time_step_size, input_vec_size)\n",
    "    XT = tf.transpose(X, [1, 0, 2])  # permute time_step_size and batch_size\n",
    "    # XT shape: (time_step_size, batch_size, input_vec_size)\n",
    "    XR = tf.reshape(XT, [-1, lstm_size]) # each row has input for each lstm cell (lstm_size=input_vec_size)\n",
    "    # XR shape: (time_step_size * batch_size, input_vec_size)\n",
    "    X_split = tf.split(XR, time_step_size, 0) # split them to time_step_size (28 arrays)\n",
    "    # Each array shape: (batch_size, input_vec_size)\n",
    "\n",
    "    # Make lstm with lstm_size (each input vector size)\n",
    "    lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)\n",
    "\n",
    "    # Get lstm cell output, time_step_size (28) arrays with lstm_size output: (batch_size, lstm_size)\n",
    "    outputs, _states = rnn.static_rnn(lstm, X_split, dtype=tf.float32)\n",
    "\n",
    "    # Linear activation\n",
    "    # Get the last output\n",
    "    return tf.matmul(outputs[-1], W) + B, lstm.state_size # State size to initialize the stat\n",
    "\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
    "trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels\n",
    "trX = trX.reshape(-1, 28, 28)\n",
    "teX = teX.reshape(-1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X = tf.placeholder(\"float\", [None, 28, 28])\n",
    "Y = tf.placeholder(\"float\", [None, 10])\n",
    "\n",
    "# get lstm_size and output 10 labels\n",
    "W = init_weights([lstm_size, 10])\n",
    "B = init_weights([10])\n",
    "\n",
    "py_x, state_size = model(X, W, B, lstm_size)\n",
    "\n",
    "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))\n",
    "train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)\n",
    "predict_op = tf.argmax(py_x, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Launch the graph in a session\n",
    "with tf.Session() as sess:\n",
    "    # you need to initialize all variables\n",
    "    tf.global_variables_initializer().run()\n",
    "\n",
    "    for i in range(100):\n",
    "        for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX)+1, batch_size)):\n",
    "            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})\n",
    "\n",
    "        test_indices = np.arange(len(teX))  # Get A Test Batch\n",
    "        np.random.shuffle(test_indices)\n",
    "        test_indices = test_indices[0:test_size]\n",
    "\n",
    "        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==\n",
    "                         sess.run(predict_op, feed_dict={X: teX[test_indices]})))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
